import json
import time
from typing import Dict, Any
from openai import OpenAI
import os


# Shared client for an OpenAI-compatible chat endpoint served on localhost.
# "EMPTY" is a placeholder key -- presumably the local server does not
# validate it (TODO confirm).
client = OpenAI(base_url="http://127.0.0.1:8021/v1", api_key="EMPTY")

def preprocess(item: Dict[str, Any]) -> Dict[str, Any]:
    """Per-item preparation hook run before prompt construction.

    Currently a no-op: the very same dict object is handed back unchanged.
    """
    prepared = item
    return prepared

def construct_prompt(item: Dict[str, Any]) -> str:
    """Build the passage-recovery prompt for one dataset item.

    The prompt presents the question and answer that bracket a hidden
    reasoning passage, tokens sampled from that passage, and documents
    retrieved with the passage's embedding. It then asks the model to return
    the reconstructed passage as JSON with a single ``"recovered_text"`` key.

    Args:
        item: Dataset record. Keys read:
            * ``"question"``
            * ``"answer"``
            * ``"sampled_token_text"`` -- iterable of str, joined with ", "
            * ``"retrieved_rag_docs"`` -- iterable of str, each stripped

            A missing key, or a key holding ``None``, is treated as empty.

    Returns:
        The complete prompt text.
    """
    # `or` also covers keys that are present but hold None. Such values would
    # otherwise render as the text "None" or crash `join`/`strip` below.
    question = item.get("question") or ""
    answer = item.get("answer") or ""
    visible_tokens = item.get("sampled_token_text") or []
    wiki_docs = item.get("retrieved_rag_docs") or []

    header = (
        "I previously generated an embedding from a passage of text, but I lost the original text.\n"
        "This hidden passage is located between two given segments, the <Question> and the <Answer>, and represents a reasoning process from the <Question> to the <Answer>. "
        "It is semantically related to both.\n"
        f"<Question>:\n{question}\n<Question>\n\n"
        f"<Answer>:\n{answer}\n<Answer>\n\n"
        f"Some visible tokens that were part of the hidden passage include:\n{', '.join(visible_tokens)}\n\n"
        "I used the embedding of the hidden passage to retrieve some relevant documents from Wikipedia. Here they are:\n"
    )

    # Docs are numbered from 1 in the prompt.
    docs = "".join(
        f"[Doc {n}]\n{doc.strip()}\n\n" for n, doc in enumerate(wiki_docs, start=1)
    )

    footer = (
        "Now, based on all this information, please help me recover the most likely content of the hidden passage.\n"
        "Return your answer in the following JSON format:\n"
        "```json\n{\n"
        '  "recovered_text": "<your reconstructed passage here>"\n'
        "}```"
    )

    return header + docs + footer

def record_result(item: Dict[str, Any], result: str) -> Dict[str, Any]:
    """Store the model response on *item* and return it.

    The dict is modified in place under the key ``"openai_response"``. The
    same object is returned so callers can write
    ``item = record_result(item, result)``.
    """
    item["openai_response"] = result
    return item


def call_openai_api(
    prompt: str,
    model: str = "/workspace/0407_nips/0514/openai/Qwen2.5-72B-Instruct",
    temperature: float = 0.7,
    max_tokens: int = 8192,
) -> str:
    """Send *prompt* as a single user turn and return the reply text.

    Uses the module-level ``client``. A fixed system message precedes the
    user turn.

    Args:
        prompt: Content of the user message.
        model: Model identifier forwarded to the server. The original version
            accepted this argument but ignored it and always used the
            hard-coded default.
        temperature: Sampling temperature forwarded to the server.
        max_tokens: Completion token limit forwarded to the server.

    Returns:
        ``message.content`` of the first choice. NOTE(review): the SDK types
        this as optional, so it may be ``None``; callers currently store it
        as-is.

    Errors raised by the client are not caught here; they propagate to the
    caller.
    """
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt},
        ],
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return response.choices[0].message.content


def main(input_path: str, output_path: str, sleep_seconds: float = 5) -> None:
    """Run every item in *input_path* through the model and stream results out.

    For each item:
    1. Build the prompt and store it under ``"prompt"``.
    2. Query the model.
    3. Store the reply under ``"openai_response"``.
    4. Append the item to *output_path* right away, so results written
       before a failure are kept.

    Args:
        input_path: JSON file whose top-level value is iterated as a list of
            item dicts.
        output_path: Destination file. Any existing file at this path is
            replaced. Each item is written as an indented JSON object followed
            by ``",\\n"``. The file is therefore a comma-separated sequence of
            objects, not a single valid JSON document.
        sleep_seconds: Pause between consecutive API calls. Defaults to 5,
            matching the previous fixed delay. No pause follows the last item.
    """
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    total = len(data)
    # "w" truncates any output left from a previous run. It replaces the
    # earlier remove-then-reopen-in-append-mode-per-item approach.
    with open(output_path, "w", encoding="utf-8") as out:
        for i, item in enumerate(data):
            print(f"{i+1}/{total}")
            item = preprocess(item)
            prompt = construct_prompt(item)
            item["prompt"] = prompt
            result = call_openai_api(prompt)
            item = record_result(item, result)

            out.write(json.dumps(item, ensure_ascii=False, indent=2) + ",\n")
            # Flush per item so finished results reach the OS even if a later
            # API call raises and aborts the run.
            out.flush()

            # Throttle between requests only; nothing follows the last item.
            if i + 1 < total:
                time.sleep(sleep_seconds)

    print(f"{output_path}")


if __name__ == "__main__":
    import argparse

    # Both paths are optional positionals. Running with no arguments uses the
    # same hard-coded paths as before.
    parser = argparse.ArgumentParser(
        description="Query the model to recover hidden reasoning passages for each input item."
    )
    parser.add_argument(
        "input_file",
        nargs="?",
        default="/workspace/0407_nips/0514/openai/data_first_samples/1024_output_test_OpenR1-Math-220k_100.json",
        help="JSON file containing the list of items to process.",
    )
    parser.add_argument(
        "output_file",
        nargs="?",
        default="output_with_results_1024_first_samples.json",
        help="Destination file for the per-item results (overwritten).",
    )
    args = parser.parse_args()
    main(args.input_file, args.output_file)
